Temperature Diagnostics

Once you copy this repository, feel free to delete this notebook!

Imports

# Display output of plots directly in Notebook
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

import intake
import numpy as np
import pandas as pd
import xarray as xr
import ast
from ncar_jobqueue import NCARCluster
from distributed import Client

Spin up a Cluster

# Request a Dask cluster on NCAR's HPC system (PBS, per the output below)
# with 10 GB of memory per worker
cluster = NCARCluster(memory='10 GB')
# Scale out to 20 workers
cluster.scale(20)
# Attach a Dask client; displaying it renders the dashboard link
client = Client(cluster)
client

Client

Client-8f11c72f-1196-11ec-9cd4-3cecef1a526c

Connection method: Cluster object Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mgrover/proxy/8787/status

Cluster Info

Data Ingest

# Open the intake-ESM catalog. The "variables" CSV column holds a stringified
# Python list, so it is parsed back into a list with ast.literal_eval.
data_catalog = intake.open_esm_datastore('data/silver-linings-aws-year1.json', csv_kwargs={"converters": {"variables": ast.literal_eval}},)

Subset for monthly-frequency output (which contains the 2m temperature variable, TREFHT)

data_subset = data_catalog.search(frequency='month_1')

Read in the dictionary of datasets

dsets = data_subset.to_dataset_dict(cdf_kwargs={'chunks':{'time':-1}})
--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.stream.case'
100.00% [1/1 00:00<00:00]
# Grab the first (and only) dataset from the dictionary returned by intake-ESM.
# (A previous direct xr.open_dataset call here was a dead store — its result
# was immediately overwritten on the next line — so it has been removed.)
ds = dsets[list(dsets.keys())[0]]
ds
<xarray.Dataset>
Dimensions:                  (lat: 192, lon: 288, zlon: 1, nbnd: 2, lev: 70, ilev: 71, time: 12)
Coordinates:
  * lat                      (lat) float64 -90.0 -89.06 -88.12 ... 89.06 90.0
  * lon                      (lon) float64 0.0 1.25 2.5 ... 356.2 357.5 358.8
  * zlon                     (zlon) float64 0.0
  * lev                      (lev) float64 5.96e-06 9.827e-06 ... 976.3 992.6
  * ilev                     (ilev) float64 4.5e-06 7.42e-06 ... 985.1 1e+03
  * time                     (time) object 2035-02-01 00:00:00 ... 2036-01-01...
Dimensions without coordinates: nbnd
Data variables: (12/848)
    gw                       (lat) float64 dask.array<chunksize=(192,), meta=np.ndarray>
    zlon_bnds                (zlon, nbnd) float64 dask.array<chunksize=(1, 2), meta=np.ndarray>
    hyam                     (lev) float64 dask.array<chunksize=(70,), meta=np.ndarray>
    hybm                     (lev) float64 dask.array<chunksize=(70,), meta=np.ndarray>
    P0                       float64 ...
    hyai                     (ilev) float64 dask.array<chunksize=(71,), meta=np.ndarray>
    ...                       ...
    soa5_c1SFWET             (time, lat, lon) float32 dask.array<chunksize=(1, 192, 288), meta=np.ndarray>
    soa5_c2                  (time, lev, lat, lon) float32 dask.array<chunksize=(1, 70, 192, 288), meta=np.ndarray>
    soa5_c2DDF               (time, lat, lon) float32 dask.array<chunksize=(1, 192, 288), meta=np.ndarray>
    soa5_c2SFWET             (time, lat, lon) float32 dask.array<chunksize=(1, 192, 288), meta=np.ndarray>
    wet_deposition_NHx_as_N  (time, lat, lon) float32 dask.array<chunksize=(1, 192, 288), meta=np.ndarray>
    wet_deposition_NOy_as_N  (time, lat, lon) float32 dask.array<chunksize=(1, 192, 288), meta=np.ndarray>
Attributes:
    host:                    
    intake_esm_varname:      ['date', 'datesec', 'date_written', 'time_writte...
    source:                  CAM
    Conventions:             CF-1.0
    initial_file:            b.e21.BWSSP245cmip6.f09_g17.CMIP6-SSP2-4.5-WACCM...
    logname:                 geostrat
    topography_file:         /scratch/geostrat/inputdata/atm/cam/topo/fv_0.9x...
    time_period_freq:        month_1
    model_doi_url:           https://doi.org/10.5065/D67H1H0V
    case:                    b.e21.BW.f09_g17.SSP245-TSMLT-GAUSS-LOWER-0.5.001
    intake_esm_dataset_key:  atm.cam.h0.b.e21.BW.f09_g17.SSP245-TSMLT-GAUSS-L...

Data Operation

First, we set up a few helper functions…

def area_grid(lat, lon):
    """
    Calculate the area of each grid cell.
    Area is in square meters.

    Input
    -----------
    lat: vector of latitude in degrees
    lon: vector of longitude in degrees

    Output
    -----------
    area: grid-cell area in square-meters with dimensions, [lat,lon]

    Notes
    -----------
    Based on the function in
    https://github.com/chadagreene/CDT/blob/master/cdt/cdtarea.m
    """
    # Use the module-level numpy (np) / xarray (xr) imports instead of
    # re-importing inside the function (consistent with the rest of the file).
    xlon, ylat = np.meshgrid(lon, lat)
    # Latitude-dependent Earth radius (oblate spheroid, see earth_radius)
    R = earth_radius(ylat)

    # Angular cell widths from the gradient of the coordinate grids (radians)
    dlat = np.deg2rad(np.gradient(ylat, axis=0))
    dlon = np.deg2rad(np.gradient(xlon, axis=1))

    # Physical edge lengths: east-west distance shrinks with cos(latitude)
    dy = dlat * R
    dx = dlon * R * np.cos(np.deg2rad(ylat))

    area = dy * dx

    xda = xr.DataArray(
        area,
        dims=["lat", "lon"],
        coords={"lat": lat, "lon": lon},
        attrs={
            "long_name": "area_per_pixel",
            "description": "area per pixel",
            "units": "m^2",
        },
    )
    return xda
def earth_radius(lat):
    '''
    calculate radius of Earth assuming oblate spheroid
    defined by WGS84

    Input
    ---------
    lat: vector or latitudes in degrees

    Output
    ----------
    r: vector of radius in meters

    Notes
    -----------
    WGS84: https://earth-info.nga.mil/GandG/publications/tr8350.2/tr8350.2-a/Chapter%203.pdf
    '''
    # Use the module-level numpy import consistently; the previous
    # function-local `from numpy import deg2rad, sin, cos` left `sin`/`cos`
    # unused and mixed two import styles in one body.

    # define oblate spheroid from WGS84: semi-major/semi-minor axes (meters)
    a = 6378137
    b = 6356752.3142
    # first eccentricity squared
    e2 = 1 - (b**2 / a**2)

    # convert from geodetic to geocentric latitude
    # see equation 3-110 in WGS84
    lat = np.deg2rad(lat)
    lat_gc = np.arctan((1 - e2) * np.tan(lat))

    # radius equation
    # see equation 3-107 in WGS84
    r = (
        (a * (1 - e2)**0.5)
        / (1 - (e2 * np.cos(lat_gc)**2))**0.5
    )

    return r
def center_time(ds):
    """Recompute the time coordinate as the midpoint of the time bounds."""
    ds = ds.copy()
    # Preserve the original time metadata so it survives the reassignment
    saved_attrs = ds.time.attrs
    saved_encoding = ds.time.encoding

    try:
        tb_name, tb_dim = _get_tb_name_and_tb_dim(ds)
        ds['time'] = ds[tb_name].compute().mean(tb_dim).squeeze()
        saved_attrs['note'] = f'time recomputed as {tb_name}.mean({tb_dim})'
    except AssertionError:
        # No usable time-bounds variable; keep the time values as-is
        print('Using default time values')

    ds.time.attrs = saved_attrs
    ds.time.encoding = saved_encoding
    return ds

def _get_tb_name_and_tb_dim(ds):
    """Return the name of the time 'bounds' variable and its second dimension.

    Raises AssertionError (deliberately caught by center_time) when the
    dataset lacks a usable time-bounds variable.
    """
    assert 'bounds' in ds.time.attrs, 'missing "bounds" attr on time'
    bounds_name = ds.time.attrs['bounds']
    assert bounds_name in ds, f'missing "{bounds_name}"'
    bounds_dim = ds[bounds_name].dims[-1]
    return bounds_name, bounds_dim
def calc_area_weighted_mean(ds, resample=True, sample_freq='AS'):
    """Compute the area-weighted spatial mean of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset with 'time', 'lat', and 'lon' dimensions.
    resample : bool, optional
        If truthy (default), resample in time at ``sample_freq`` before
        averaging. (Previously this argument was accepted but ignored;
        it now gates the resampling step — the truthy/default behavior
        is unchanged.)
    sample_freq : str, optional
        Resample frequency string, e.g. 'AS' (annual) or '1M' (monthly).

    Returns
    -------
    xarray.Dataset
        Area-weighted mean; relies on the module-level ``da_area`` and
        ``total_area`` weights defined elsewhere in this notebook.
    """
    ds = center_time(ds.sortby('time'))
    if resample:
        ds = ds.resample(time=sample_freq).mean("time")
    # Weight each grid cell by its area, then normalize by the total area
    ds_out = (ds * da_area).sum(dim=("lat", "lon")) / total_area
    return ds_out

def convert_to_df(ds):
    """Convert the TREFHT variable to a DataFrame (time rows, case columns)."""
    series = ds.TREFHT.to_series()
    # unstack the innermost index level into columns, then transpose
    return series.unstack().T

Compute the area for the weights

# Per-grid-cell areas (m^2), used as weights for the spatial mean
da_area = area_grid(ds['lat'], ds['lon'])

# Total area: normalization factor for the area-weighted mean
total_area = da_area.sum(['lat','lon'])

Setup which variables to average

variables = ['TREFHT']

Run the computation on each dataset

# Keep dataset/variable attributes through xarray operations
xr.set_options(keep_attrs=True)
ds_list = []
for key in dsets.keys():
    ds = dsets[key]
    # BUG FIX: '1M' was previously passed positionally and therefore bound to
    # the (unused) `resample` argument, leaving sample_freq at the default
    # annual 'AS'. Pass it by keyword so the mean is computed monthly, as the
    # 12-monthly-value output below indicates was intended.
    mean = calc_area_weighted_mean(ds, sample_freq='1M')
    out = mean[variables]
    out.attrs['intake_esm_varname'] = variables
    out.attrs['case'] = ds.case
    ds_list.append(out)

Here we add additional case information

# Collect the case name from each per-case dataset
cases = [ds.case for ds in ds_list]
# Stack the per-case means along a new 'case' dimension and label it
merged_ds = xr.concat(ds_list, dim='case')
merged_ds['case'] = cases
merged_ds.persist()
<xarray.Dataset>
Dimensions:  (case: 1, time: 12)
Coordinates:
  * time     (time) object 2035-01-31 00:00:00 ... 2035-12-31 00:00:00
  * case     (case) <U49 'b.e21.BW.f09_g17.SSP245-TSMLT-GAUSS-LOWER-0.5.001'
Data variables:
    TREFHT   (case, time) float64 dask.array<chunksize=(1, 1), meta=np.ndarray>
Attributes:
    intake_esm_varname:  ['TREFHT']
    case:                b.e21.BW.f09_g17.SSP245-TSMLT-GAUSS-LOWER-0.5.001

Use datetime index instead of cftime

# Convert the cftime-based time index to a standard pandas DatetimeIndex so
# downstream tooling (pandas/matplotlib) handles the time axis natively
datetimeindex = merged_ds.indexes['time'].to_datetimeindex()
merged_ds['time'] = datetimeindex

Save the dataset

merged_ds.to_netcdf('data/global_average_temperature_aws_year1.nc')